from typing import List, Iterator, Dict, Tuple, Any
from collections import defaultdict
import math
import numpy as np
import torch

from torch.utils.hooks import RemovableHandle
from torch import Tensor

from allennlp.predictors import Predictor
from allennlp.models import Model
from allennlp.data import DatasetReader, Instance
from allennlp.data import Batch
from allennlp.nn import util
from allennlp.common.util import JsonDict

from tools.utils import get_model_parm_nums, mode


class TextPredictor(Predictor):
    """Predictor for instances built from a ``sentence`` / ``label`` JSON pair.

    On construction the wrapped model is put into evaluation mode and its
    size (parameter count divided by 1e6, as reported by
    ``get_model_parm_nums``) is cached in ``self.model_size``.
    """

    # ``model_size`` (count / 1e6) above which the smaller token budget is used.
    _LARGE_MODEL_SIZE = 10
    # Budgets for ``batch_size * longest_sentence_length`` in ``predict_many_json``.
    _LARGE_MODEL_TOKEN_BUDGET = 7000
    _SMALL_MODEL_TOKEN_BUDGET = 50000

    def __init__(self, model: Model, dataset_reader: DatasetReader):
        super().__init__(model, dataset_reader)
        model.eval()
        self.model_size = self.get_model_size()

    def get_gradients(self,
                      instances: List[Instance]) -> Dict[str, Any]:
        """Run a forward/backward pass and return the collected gradients.

        Args:
            instances: Instances to batch, index with the model vocabulary and
                feed through the model.

        Returns:
            A dict mapping ``'grad_input_1'``, ``'grad_input_2'``, ... to the
            detached gradient tensors that the hooks registered by the
            inherited ``_register_embedding_gradient_hooks`` appended to the
            collection list, in the order they were appended.

        The model is always switched back to evaluation mode and the hooks are
        always removed, even when the forward or backward pass raises.
        """
        embedding_gradients: List[Tensor] = []
        hooks: List[RemovableHandle] = self._register_embedding_gradient_hooks(embedding_gradients)

        try:
            dataset = Batch(instances)
            dataset.index_instances(self._model.vocab)
            model_input = util.move_to_device(dataset.as_tensor_dict(),
                                              self._model._get_prediction_device())

            # NOTE(review): the backward pass is run in training mode —
            # presumably required by some layers of the model; confirm.
            self._model.train()

            outputs = self._model.decode(self._model(**model_input))
            loss = outputs['loss']

            self._model.zero_grad()
            loss.backward()
        finally:
            # Restore inference state regardless of success so a failed call
            # does not leave stale hooks or a model stuck in training mode.
            self._model.eval()
            for hook in hooks:
                hook.remove()

        return {f'grad_input_{idx + 1}': grad.detach()
                for idx, grad in enumerate(embedding_gradients)}

    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """Build an instance from the ``"sentence"`` and ``"label"`` entries.

        Raises:
            KeyError: if either key is missing from ``json_dict``.
        """
        sentence = json_dict["sentence"]
        label = json_dict['label']
        return self._dataset_reader.text_to_instance(sentence, label)

    def predict_many_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
        """Predict an arbitrary number of JSON inputs in size-limited batches.

        The batch size is chosen so that ``batch_size * longest_sentence``
        stays within a token budget that depends on ``self.model_size``.

        Args:
            inputs: JSON dicts accepted by ``_json_to_instance``.

        Returns:
            The concatenated outputs of ``predict_batch_instance`` over all
            batches; an empty list when ``inputs`` yields no instances.
        """
        instances = self._batch_json_to_instances(inputs)
        if not instances:
            return []

        num_tokens = max(len(instance["sentence"].tokens) for instance in instances)
        if self.model_size > self._LARGE_MODEL_SIZE:
            token_budget = self._LARGE_MODEL_TOKEN_BUDGET
        else:
            token_budget = self._SMALL_MODEL_TOKEN_BUDGET
        # Clamp both operands: an empty sentence must not divide by zero and a
        # sentence longer than the budget must still get a batch of one.
        batch_size = max(1, token_budget // max(1, num_tokens))

        outputs = []
        for start in range(0, len(instances), batch_size):
            outputs.extend(self.predict_batch_instance(instances[start:start + batch_size]))
        return outputs

    def get_model_size(self):
        """Return ``get_model_parm_nums(model) / 1e6``."""
        return get_model_parm_nums(self._model) / 1e6

    def get_params_l2(self):
        """Return the sum of squared parameter values divided by ``get_model_parm_nums``."""
        regularization_loss = 0
        for param in self._model.parameters():
            regularization_loss += torch.sum(param ** 2)
        return regularization_loss / get_model_parm_nums(self._model)


class TextPredictorEnsembler:
    """Ensemble of predictors that exposes the same ``predict_*`` entry points.

    Each ``predict_*`` method forwards the call to every wrapped predictor and
    merges the per-predictor outputs with :meth:`ensemble`.
    """

    def __init__(self, preditors: List["TextPredictor"]):
        # NOTE: the parameter name is misspelled; it is kept unchanged so that
        # existing keyword callers keep working.
        self.predictors = preditors

    def predict_batch_json(self, inputs):
        """Ensemble the predictors' ``predict_batch_json`` results."""
        return self.ensemble('predict_batch_json', inputs)

    def predict_many_json(self, inputs):
        """Ensemble the predictors' ``predict_many_json`` results."""
        return self.ensemble('predict_many_json', inputs)

    def predict_instance(self, instance):
        """Ensemble the predictors' ``predict_instance`` results (a single dict)."""
        return self.ensemble('predict_instance', instance)

    def predict_batch_instance(self, instances):
        """Ensemble the predictors' ``predict_batch_instance`` results."""
        return self.ensemble('predict_batch_instance', instances)

    @staticmethod
    def _as_batches(outputs):
        """Normalise per-predictor outputs to lists of dicts.

        Returns ``(outputs, single)`` where ``single`` is True when every
        predictor returned one dict rather than a list of dicts; in that case
        each dict is wrapped in a one-element list.
        """
        single = bool(outputs) and isinstance(outputs[0], dict)
        if single:
            outputs = [[out] for out in outputs]
        return outputs, single

    def ensemble(self, func, inputs):
        """Call ``func`` on every predictor and average the resulting logits.

        Args:
            func: Name of the predictor method to invoke.
            inputs: Argument passed unchanged to that method.

        Returns:
            One dict per input item, built by splitting the output of the first
            predictor's ``_model.warp_outputs`` (minus ``'loss'``) along its
            first dimension and converting each piece with ``.numpy()``.
            When the predictors return a single dict, a single dict is
            returned. An empty list is returned when there is nothing to
            ensemble.
        """
        outputs = [getattr(predictor, func)(inputs) for predictor in self.predictors]
        outputs, single = self._as_batches(outputs)

        ensemble_logits, ensemble_label = [], []
        for model_outputs in zip(*outputs):
            # Unweighted mean over predictors; a weighted variant was tried
            # previously but is not in use.
            logits = np.stack([out['logits'] for out in model_outputs]).mean(axis=0)
            ensemble_logits.append(logits)
            # The gold label is read from the first predictor's output only.
            ensemble_label.append(model_outputs[0]['gold'])

        if not ensemble_label:
            return []

        # Stack first: building a tensor from a list of ndarrays is slow.
        wrap_outputs = self.predictors[0]._model.warp_outputs(
            torch.tensor(np.stack(ensemble_logits)),
            torch.tensor(ensemble_label))
        wrap_outputs.pop('loss', None)

        split_output = [dict() for _ in ensemble_label]
        for name, output in wrap_outputs.items():
            for i, o in enumerate(output):
                split_output[i][name] = o.numpy()

        return split_output[0] if single else split_output

    def get_model_size(self):
        """Return the sum of the wrapped predictors' ``get_model_size()`` values."""
        return np.sum([f.get_model_size() for f in self.predictors])

    def ensemble2(self, func, input):
        """Call ``func`` on every predictor and reduce each output key separately.

        Every key is reduced across predictors with the function listed in
        ``ensemble_criterion`` (falling back to ``'default'``), called as
        ``criterion(np.stack(values), axis=0)``.

        Args:
            func: Name of the predictor method to invoke.
            input: Argument passed unchanged to that method. The name shadows
                the builtin but is kept for caller compatibility.

        Returns:
            One reduced dict per input item, or a single dict when the
            predictors return a single dict.
        """
        ensemble_criterion = {'gold_prob': np.max,
                              'gold': np.mean,
                              'pred': mode,
                              'logits': np.mean,
                              'default': np.mean,
                              }
        outputs = [getattr(f, func)(input) for f in self.predictors]
        outputs, single = self._as_batches(outputs)

        ensemble_outputs = []
        for data_out in zip(*outputs):
            # Gather, per key, the values produced by each predictor.
            collected = defaultdict(list)
            for model_out in data_out:
                for k, v in model_out.items():
                    collected[k].append(v)
            ensemble_outputs.append({
                k: ensemble_criterion.get(k, ensemble_criterion['default'])(np.stack(v), axis=0)
                for k, v in collected.items()
            })
        return ensemble_outputs[0] if single else ensemble_outputs
